use std::collections::{HashMap, HashSet};
use std::sync::OnceLock;

use goblin::pe::section_table::SectionTable;

// TODO: Check for more false positives
const FALSE_POSITIVES: [[u8; 32]; 2] = [
    [0x6F, 0x16, 0x80, 0x73, 0xB9, 0xB2, 0x14, 0x49, 0xD7, 0x42, 0x24, 0x17, 0x00, 0x06, 0x8A, 0xDA, 0xBC, 0x30, 0x6F, 0xA9, 0xAA, 0x38, 0x31, 0x16, 0x4D, 0xEE, 0x8D, 0xE3, 0x4E, 0x0E, 0xFB, 0xB0],
    [0x67, 0xE6, 0x09, 0x6A, 0x85, 0xAE, 0x67, 0xBB, 0x72, 0xF3, 0x6E, 0x3C, 0x3A, 0xF5, 0x4F, 0xA5, 0x7F, 0x52, 0x0E, 0x51, 0x8C, 0x68, 0x05, 0x9B, 0xAB, 0xD9, 0x83, 0x1F, 0x19, 0xCD, 0xE0, 0x5B]
];

struct Filter {
    offsets: HashMap<usize, &'static [u8; 8]>,
    locator: offset_finder::OffsetLocator<'static>,
}

static RESTRICTED_FILTER: OnceLock<Filter> = OnceLock::new();
static RELAXED_FILTER: OnceLock<Filter> = OnceLock::new();

fn get_restricted_filter() -> &'static Filter {
    RESTRICTED_FILTER.get_or_init(|| {
        let mut offsets: HashMap<usize, &'static [u8; 8]> = HashMap::new();
        offsets.insert(0, &[2, 9, 16, 23, 30, 37, 44, 51]);
        offsets.insert(1, &[3, 10, 17, 24, 35, 42, 49, 56]);
        offsets.insert(2, &[3, 14, 25, 32, 44, 51, 58, 65]);
        offsets.insert(3, &[3, 10, 21, 28, 35, 42, 49, 56]);
        offsets.insert(4, &[3, 10, 21, 28, 35, 42, 49, 56]);
        Filter {
            offsets,
            locator: offset_finder::OffsetLocator {
                name: "AES",
                partial_match: vec![
                    "c7 01 ?? ?? ?? ?? c7 41 04 ?? ?? ?? ?? c7 41 08 ?? ?? ?? ?? c7 41 0c ?? ?? ?? ?? c7 41 10 ?? ?? ?? ?? c7 41 14 ?? ?? ?? ?? c7 41 18 ?? ?? ?? ?? c7 41 1c ?? ?? ?? ?? c3",
                    "c7 45 d0 ?? ?? ?? ?? c7 45 d4 ?? ?? ?? ?? c7 45 d8 ?? ?? ?? ?? c7 45 dc ?? ?? ?? ?? 0f ?? ?? ?? c7 45 e0 ?? ?? ?? ?? c7 45 e4 ?? ?? ?? ?? c7 45 e8 ?? ?? ?? ?? c7 45 ec ?? ?? ?? ?? 0f",
                    "c7 45 d0 ?? ?? ?? ?? ?? ?? ?? ?? c7 45 d4 ?? ?? ?? ?? ?? ?? ?? ?? c7 45 d8 ?? ?? ?? ?? c7 45 dc ?? ?? ?? ?? ?? ?? ?? ?? ?? c7 45 e0 ?? ?? ?? ?? c7 45 e4 ?? ?? ?? ?? c7 45 e8 ?? ?? ?? ?? c7 45 ec ?? ?? ?? ??",
                    "c7 45 d0 ?? ?? ?? ?? c7 45 d4 ?? ?? ?? ?? ?? ?? ?? ?? c7 45 d8 ?? ?? ?? ?? c7 45 dc ?? ?? ?? ?? c7 45 e0 ?? ?? ?? ?? c7 45 e4 ?? ?? ?? ?? c7 45 e8 ?? ?? ?? ?? c7 45 ec ?? ?? ?? ??",
                    "c7 45 ?? ?? ?? ?? ?? c7 45 ?? ?? ?? ?? ?? ?? ?? ?? ?? c7 45 ?? ?? ?? ?? ?? c7 45 ?? ?? ?? ?? ?? c7 45 ?? ?? ?? ?? ?? c7 45 ?? ?? ?? ?? ?? c7 45 ?? ?? ?? ?? ?? c7 45 ?? ?? ?? ?? ??",
                ],
                full_match: "",
                skip_offset_print: false,
                allow_multiple_matches: true,
            },
        }
    })
}

fn get_relaxed_filter() -> &'static Filter {
    RELAXED_FILTER.get_or_init(|| {
        let mut offsets: HashMap<usize, &'static [u8; 8]> = HashMap::new();
        offsets.insert(0, &[3, 10, 17, 24, 35, 42, 49, 56]);
        offsets.insert(1, &[2, 9, 16, 23, 30, 37, 44, 51]);
        offsets.insert(2, &[3, 10, 21, 28, 35, 42, 49, 56]);
        Filter {
            offsets,
            locator: offset_finder::OffsetLocator {
                name: "AES",
                partial_match: vec![
                    "c7 ?? ?? ?? ?? ?? ?? c7 ?? ?? ?? ?? ?? ?? c7 ?? ?? ?? ?? ?? ?? c7 ?? ?? ?? ?? ?? ?? ?? ?? ?? ?? c7 ?? ?? ?? ?? ?? ?? c7 ?? ?? ?? ?? ?? ?? c7 ?? ?? ?? ?? ?? ?? c7 ?? ?? ?? ?? ?? ??",
                    "c7 ?? ?? ?? ?? ?? c7 ?? ?? ?? ?? ?? ?? c7 ?? ?? ?? ?? ?? ?? c7 ?? ?? ?? ?? ?? ?? c7 ?? ?? ?? ?? ?? ?? c7 ?? ?? ?? ?? ?? ?? c7 ?? ?? ?? ?? ?? ?? c7 ?? ?? ?? ?? ?? ??",
                    "c7 ?? ?? ?? ?? ?? ?? c7 ?? ?? ?? ?? ?? ?? ?? ?? ?? ?? c7 ?? ?? ?? ?? ?? ?? c7 ?? ?? ?? ?? ?? ?? c7 ?? ?? ?? ?? ?? ?? c7 ?? ?? ?? ?? ?? ?? c7 ?? ?? ?? ?? ?? ?? c7 ?? ?? ?? ?? ?? ??",
                ],
                full_match: "",
                skip_offset_print: false,
                allow_multiple_matches: true,
            },
        }
    })
}

pub fn dump_aes_key_restricted(image_base: usize,
                               sections: &[SectionTable],
                               data: &[u8]) -> Result<HashSet<Vec<u8>>, offset_finder::Error> {
    dump_aes_key_internal(image_base, sections, data, get_restricted_filter())
}

pub fn dump_aes_key(image_base: usize,
                    sections: &[SectionTable],
                    data: &[u8]) -> Result<HashSet<Vec<u8>>, offset_finder::Error> {
    dump_aes_key_internal(image_base, sections, data, get_relaxed_filter())
}

fn dump_aes_key_internal(image_base: usize,
                         sections: &[SectionTable],
                         data: &[u8],
                         filter: &Filter) -> Result<HashSet<Vec<u8>>, offset_finder::Error> {
    let results = filter.locator.find_all_partial_only(image_base, sections, data)?;
    // Probabilistic allocation, 50% or more will be false positives, so preallocate (n / 2) + 1
    let mut output: HashSet<Vec<u8>> = HashSet::with_capacity((results.len() / 2) + 1);
    for outer in results {
        let offset = *filter.offsets.get(&outer.0).unwrap();
        for inner in outer.1 {
            let mut key = Vec::with_capacity(32);
            for tmp in offset {
                let tmp = inner.0 + (*tmp as usize);
                key.extend_from_slice(&data[tmp..tmp + 4]);
            }
            let mut should_add = true;
            for false_positive in FALSE_POSITIVES {
                if false_positive.eq(&key[0..32]) {
                    should_add = false;
                    break;
                }
            }
            if should_add {
                output.insert(key);
            }
        }
    }
    Ok(output)
}